package com.gitee.kamismile.stoneComEx.common.filter.server;

import com.gitee.kamismile.stone.commmon.base.ResultVO;
import com.gitee.kamismile.stone.commmon.util.JsonUtil;
import com.gitee.kamismile.stoneComEx.util.IpUtil;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;


/**
 * Reactive filter that rate-limits requests per client IP using Bucket4j token buckets.
 *
 * <p>Each distinct IP returned by {@link IpUtil#getIpByServerRequest} gets its own bucket built by
 * {@link #createIPRateLimitBucket()}. A request that obtains a token is passed down the chain.
 * A request that cannot obtain one is answered directly with HTTP 509 and a JSON error body.
 *
 * <p>NOTE(review): buckets are never evicted, so the map grows with every distinct IP seen.
 * Consider an expiring cache if the filter is exposed to a large or spoofable address space.
 *
 * <p>NOTE(review): 509 is not a standard HTTP status. RFC 6585 defines 429 Too Many Requests
 * for rate limiting. The status is left unchanged here for compatibility with existing clients.
 */
public class RateLimitReactiveFilter implements ReactiveFilter, Ordered {
    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /** Used to round the wait time up to whole seconds for the Retry-After header. */
    private static final long NANOS_PER_SECOND = TimeUnit.SECONDS.toNanos(1);

    /** One token bucket per client IP, shared across concurrent requests. */
    private final Map<String, Bucket> buckets = new ConcurrentHashMap<>();

    /**
     * Creates the bucket used for a single client IP.
     *
     * <p>The bucket has a capacity of 3 tokens and a 1-second refill period
     * ({@code Bandwidth.simple(3, Duration.ofSeconds(1))}).
     *
     * @return a new, full bucket
     */
    public Bucket createIPRateLimitBucket() {
        Bandwidth limit = Bandwidth.simple(3, Duration.ofSeconds(1));
        return Bucket.builder().addLimit(limit).build();
    }

    /**
     * Consumes one token from the caller's bucket and either continues the chain or rejects the
     * request.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining filter chain
     * @return the chain's result, or a completed write of the rate-limit error response
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        String ip = IpUtil.getIpByServerRequest(exchange.getRequest());

        // computeIfAbsent makes lookup-or-create atomic, so concurrent first requests from the
        // same IP cannot each install their own bucket (which would grant extra tokens).
        // NOTE(review): ConcurrentHashMap rejects null keys; if IpUtil can return null this
        // throws NullPointerException - confirm IpUtil's contract.
        Bucket bucket = buckets.computeIfAbsent(ip, key -> createIPRateLimitBucket());

        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (!probe.isConsumed()) {
            return updateExceeded(exchange, ip, probe);
        }
        return chain.filter(exchange);
    }

    /**
     * Writes the rate-limit rejection response.
     *
     * <p>The response carries status 509, a {@code Retry-After} header and a JSON body with
     * {@code code}, {@code message} and {@code errors} fields.
     *
     * @param exchange the current server exchange
     * @param ip       the already-resolved client IP, used for logging
     * @param probe    the failed consumption probe, used to compute the retry delay
     * @return a Mono that completes once the body has been written and flushed
     */
    private Mono<Void> updateExceeded(ServerWebExchange exchange, String ip, ConsumptionProbe probe) {
        logger.info("{} limit", ip);

        ResultVO errorMsg = new ResultVO();
        errorMsg.setCode(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED.value() + "");
        errorMsg.setMessage(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED.getReasonPhrase());

        Map<String, Object> result = new HashMap<>();
        result.put("code", errorMsg.getCode());
        result.put("message", errorMsg.getMessage());
        result.put("errors", Arrays.asList(errorMsg));

        exchange.getResponse().setStatusCode(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED);
        HttpHeaders headers = exchange.getResponse().getHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.set(HttpHeaders.RETRY_AFTER, String.valueOf(retryAfterSeconds(probe)));

        byte[] bytes = JsonUtil.toJson(result).getBytes(StandardCharsets.UTF_8);
        DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
        return exchange.getResponse().writeAndFlushWith(Flux.just(Flux.just(buffer)));
    }

    /**
     * Converts the probe's nanosecond wait into whole seconds, rounding up.
     *
     * <p>Rounding up ensures a client honouring Retry-After does not come back before a token
     * has been refilled.
     */
    private static long retryAfterSeconds(ConsumptionProbe probe) {
        long nanos = probe.getNanosToWaitForRefill();
        return (nanos + NANOS_PER_SECOND - 1) / NANOS_PER_SECOND;
    }

    /**
     * Returns the filter's position in the ordering.
     *
     * <p>Per Spring's {@link Ordered} contract, lower values run earlier. This filter therefore
     * runs after filters with smaller order values.
     */
    @Override
    public int getOrder() {
        return Integer.MAX_VALUE / 2;
    }
}
